// HuginApiAddonsTesters.cpp : main project file.

#include "stdafx.h"
#include <ExpandTimeSteps.h>
#include "hugin"
#include <iostream>
#include "..\HuginApiAddons\PolicyFileWriter.h"
#include <List>
#include "..\HuginApiAddons\iDIDSolver.h"
#include "..\HuginApiAddons\Simulator.h"
#include "..\HuginApiAddons\IdidEnterModels.h"
#include "..\HuginApiAddons\PolicyTreeBuilder.h"

#include "ActionNode.h"
#include <list>
#include <iostream>
#include <windows.h>
#include "BuilderCSV.h"
#include "FilterCSV.h"
#include <fstream>
#include <ctime>
#include "hugin"
#include "inputPprIdid.h"
#include <boost/property_tree/ptree.hpp>
#include <boost/property_tree/xml_parser.hpp>
#include "../HuginApiAddons/csvCalculations.h"


using boost::property_tree::ptree;
using boost::property_tree::write_xml;
using boost::property_tree::xml_writer_settings;


using namespace System;
using namespace HAPI;
using namespace std;

Domain * testExpandDomain(string inputFile, string outputFile, ExpandTimeSteps::IdType idType, int steps);
void testPolicyWriter(string inputFile, string actionPrefix, string observationPrefix, string outputFolder, string fileExt);
void testPolicyWriter(Domain * inputID, string actionPrefix, string observationPrefix, string outputFolder, string fileExt);
void testIDIDSolver(string inputFile, string model1File, string model2File, int numSteps);
void testSimulatorDID(list<string> inputFiles, string statePrefix, string observationPrefix, string actionPrefix, string utilityPrefix, int numberOfSimulations, string outputFile);
void testSimulatorIDID(string inputIDID, list<string> inputDIDs, string statePrefix, string utilityPrefix, string iObservationPrefix, string jObservationPrefix, string iActionPrefix, string jActionPRefix, int numberOfSimulations, string outputFile, bool modelWeight, string modelPrefix);
Domain * testExpandIDIDModels(string inputIDID, list<string> inputDIDs, string statePrefix, string utilityPrefix, string iObservationPrefix, string jObservationPrefix, string iActionPrefix, string jActionPRefix, string modelPrefix);
PolicyNode * testDomainToPolicyTree(string inputDID, string actionPrefix, string observationPrefix);
hash_map<string, PPR> testLearnPolicy(string inputCSV, vector<string> colHeaders, vector<string> observations, vector<string> actions, vector<string> states, int treeLength, string inputNet, string outputNet, vector<string> prefixes, string outputFilledNet, double epsilon, string mergesCSVfile, vector<string> joinStates);
vector<string> argumentToVector(string argument);

int main(int argc, char* argv[])
{
	if(argc < 2)
	{
		cout << "USEAGE: [program] [program arguments]" << endl;
		cout << "programs" << endl;
		cout << "simulator [output csv file] [prefixes s,o,a,u] [num simulations] [DID] [models net files]" << endl;
		cout << "simulator [output csv file] [prefixes s,u,oi,oj,ai,aj] [num simulations] [I-DID] [I-DID Net File] [models net files]" << endl;
		cout << "expandIDID [input net file] [output net file] [prefixes s,u,oi,oj,ai,aj,mod] [models net files]" << endl;
		cout << "expand [input net file] [output net file] [Domain Type] [time steps]" << endl;
		cout << "learnPolicy [input csv file] [column headers Model,Observation,Action,State] [observations] [Actions] [States] [input net file] [output net file] [prefixes aj,mod] [tree length] <filled in net file>" << endl;
		cout << "calculateFrequencies [mode file/folder] [file/folder] [column header]" << endl;
		cout << "mergeFiles [input folder] [chars] [output folder]" << endl;
		cout << "policyWriter [input file] [action prefix] [observation prefix] [output folder] [output file extension]" << endl;
	}
	else
	{
		string program = argv[1];
		//Simulator
		if(program == "simulator")
		{
			if(argc < 6)
			{
				cout << "USEAGE: simulator [output csv file] [prefixes s,o,a,u] [num simulations] [DID] [models net files]" << endl;
				cout << "USEAGE: simulator [output csv file] [prefixes s,u,oi,oj,ai,aj] [num simulations] [I-DID] [I-DID Net File] [models net files]" << endl;
			}
			else
			{
				//cout << "Simulator" << endl;

				string csvFile = argv[2];
				//ToDo 
				string prefixesStr = argv[3];
				int numSimulations = atoi(argv[4]);
				string domainType = argv[5];
				list<string> modelsNetFiles;

				vector<string> prefixes = argumentToVector(prefixesStr);

				if(domainType == "DID")
				{					
					for(int i = 6; i < argc; i++)
					{
						modelsNetFiles.push_back(argv[i]);
					}
					testSimulatorDID(modelsNetFiles, prefixes[0], prefixes[1], prefixes[2], prefixes[3], numSimulations, csvFile);
				}
				else if (domainType == "I-DID")
				{
					string domainNetFile = argv[6];
					for(int i = 7; i < argc; i++)
					{
						modelsNetFiles.push_back(argv[i]);
					}
					testSimulatorIDID(domainNetFile, modelsNetFiles, prefixes[0], prefixes[1], prefixes[2], prefixes[3], prefixes[4], prefixes[5], numSimulations, csvFile, true, "mod");
				}
				
				//cout << "Simulator Complete" << endl;
			}
		}

		//Expand I-DID models
		if(program == "expandIDID")
		{
			//cout << "Expand I-DID Models" << endl;

			if(argc < 6)
			{
				cout << "USEAGE: expandIDID [input net file] [output net file] [prefixes s,u,oi,oj,ai,aj,mod] [models net files]" << endl;
			}
			else
			{
				string inputNetFile = argv[2];
				string outputNetFile = argv[3];
				//ToDo 
				string prefixesStr = argv[4];
				list<string> modelsNetFiles;

				vector<string> prefixes = argumentToVector(prefixesStr);

				for(int i = 5; i < argc; i++)
				{
					modelsNetFiles.push_back(argv[i]);
				}

				time_t timeVal1=0;
				time(&timeVal1);
				Domain * outputDomain3 = testExpandIDIDModels(inputNetFile, modelsNetFiles, prefixes[0], prefixes[1], prefixes[2], prefixes[3], prefixes[4], prefixes[5], prefixes[6]);
				time_t timeVal2=0;
				time(&timeVal2);
				double duration = (double)(timeVal2-timeVal1)/1000.00;

				cout << "Duration = " << duration << "s" << endl;

				outputDomain3->saveAsNet(outputNetFile);

			}

			//cout << "Expand Complete" << endl;
		}

		//Expand I-DID models
		if(program == "expand")
		{
			//cout << "Expand Domains" << endl;

			if(argc < 6)
			{
				cout << "USEAGE: expand [input net file] [output net file] [Domain Type] [time steps]" << endl;
			}
			else
			{
				string inputNetFile = argv[2];
				string outputNetFile = argv[3];
				string domainType = argv[4];
				int steps = atoi(argv[5]);
				list<string> modelsNetFiles;

				ExpandTimeSteps::IdType IDType;

				if(domainType == "LIMID")
				{
					IDType = ExpandTimeSteps::LIMID;
				}else if(domainType == "DID")
				{
					IDType = ExpandTimeSteps::DID;
				}else if(domainType == "I-DID")
				{
					IDType = ExpandTimeSteps::IDID;
				}

				testExpandDomain(inputNetFile, outputNetFile, IDType, steps);

			}

			//cout << "Expand Complete" << endl;
		}

		if(program == "learnPolicy")
		{
			if(argc < 9)
			{
				cout << "USEAGE: learnPolicy [input csv file] [column headers Model,Observation,Action,State] [observations] [Actions] [States] [input net file] [output net file] [prefixes aj,mod] [tree length] <filled in net file> <merges csv>" << endl;
			}
			else
			{
				string inputCSV = argv[2];

				string colHeadersStr = argv[3];
				string observationsStr = argv[4];
				string actionsStr = argv[5];
				string statesStr = argv[6];

				string inputNet = argv[7];
				string outputNet = argv[8];
				string netPrefixesStr = argv[9];
				vector<string> netPrefixes = argumentToVector(netPrefixesStr);

				int treeLength = atoi(argv[10]);
				string outputFilledNet = "";
				double epsilon = 0.0;
				string mergesRecordFile;

				vector<string> observations = argumentToVector(observationsStr);
				vector<string> actions = argumentToVector(actionsStr);
				vector<string> states = argumentToVector(statesStr);
				vector<string> colHeaders = argumentToVector(colHeadersStr);

				vector<string> joinStates;
				joinStates.push_back(colHeaders[0]);
				joinStates.push_back(colHeaders[1]);

				if(argc > 12)
				{
					outputFilledNet = argv[11];
					epsilon = atof(argv[12]);
					mergesRecordFile = argv[13];
				}

				testLearnPolicy(inputCSV, colHeaders, observations, actions, states, treeLength, inputNet, outputNet, netPrefixes, outputFilledNet, epsilon, mergesRecordFile, joinStates);
			}
		}

		if(program == "calculateAverages")
		{
			if(argc < 5)
			{
				cout << "USEAGE: calculateAverages [mode file/folder] [file/folder] [column header] <min val> <max val>" << endl;
			}
			else
			{
				string mode = argv[2];
				string fileFolder = argv[3];
				string colHeader = argv[4];

				hash_map<string, double> results;

				csvCalculations calculator;
				if(argc > 5)
				{
					double minVal = atof(argv[5]);
					double maxVal = atof(argv[6]);
					if (mode == "folder")
					{
						results = calculator.CalculateAveragesForFolder(fileFolder, colHeader, minVal, maxVal);
					}
					else if(mode == "file")
					{
						//ToDo
					}
					else
					{
						cout << "USEAGE: calculateAverages [mode file/folder] [file/folder] [column header] <min val> <max val>" << endl;
					}
					
				}
				else
				{
					if (mode == "folder")
					{
						results = calculator.CalculateAveragesForFolder(fileFolder, colHeader);
					}
					else if(mode == "file")
					{
						//ToDo
					}
					else
					{
						cout << "USEAGE: calculateAverages [mode file/folder] [file/folder] [column header] <min val> <max val>" << endl;
					}
				}

				list<string> files;

				for(hash_map<string, double>::iterator it = results.begin(); it != results.end(); it++)
				{
					string file = it->first;
					files.push_back(file);
				}

				files.sort();

				for(list<string>::iterator it = files.begin(); it != files.end(); it++)
				{
					double result = results[*it];
					cout << *it << "\t" << result << endl;
				}

				system("pause");
			}
		}

		if(program == "calculateFrequencies")
		{
			if(argc < 5)
			{
				cout << "USEAGE: calculateFrequencies [mode file/folder] [file/folder] [column header]" << endl;
			}
			else
			{
				string mode = argv[2];
				string fileFolder = argv[3];
				string colHeader = argv[4];

				hash_map<string, hash_map<double, int>> results;

				csvCalculations calculator;
				if (mode == "folder")
				{
					results = calculator.CalculateFrequenciesForFolder(fileFolder, colHeader);
				}
				else if(mode == "file")
				{
					//results = calculator.CalculateFrequenciesForFile(fileFolder, colHeader);
				}
				else
				{
					cout << "USEAGE: calculateFrequencies [mode file/folder] [file/folder] [column header]" << endl;
				}

				list<string> files;

				for(hash_map<string, hash_map<double, int>>::iterator it = results.begin(); it != results.end(); it++)
				{
					string file = it->first;
					files.push_back(file);
				}

				files.sort();

				for(list<string>::iterator it = files.begin(); it != files.end(); it++)
				{
					hash_map<double, int> result = results[*it];
					cout << *it << endl;

					//print out hash_map contents
					for(hash_map<double, int>::iterator itr = result.begin(); itr != result.end(); itr++)
					{
						double value = itr->first;
						int freq = itr->second;

						cout << "\t" << value << "\t" << freq << endl;
					}
				}

				system("pause");
			}
		}

		if(program == "mergeFiles")
		{
			if(argc != 5)
			{
				cout << "USEAGE: mergeFiles [input folder] [chars] [output folder]" << endl;
			}
			else
			{
				string inputFolder = argv[2];
				int chars = atoi(argv[3]);
				string outputFolder = argv[4];

				csvCalculations calculator;
				calculator.MergeFilesInFolder(inputFolder, chars, outputFolder);

				system("pause");
			}
		}

		if(program == "policyWriter")
		{
			if(argc != 7)
			{
				cout << "USEAGE: policyWriter [input file] [action prefix] [observation prefix] [output folder] [output file extension]" << endl;
			}
			else
			{
				string inputFile = argv[2];
				string actionPrefix = argv[3];
				string observationPrefix = argv[4];
				string outputFolder = argv[5];
				string fileExt = argv[6];
				testPolicyWriter(inputFile, actionPrefix, observationPrefix, outputFolder, fileExt);
			}
		}

	}

    return 0;
}

// Parses a bracketed, comma-separated command-line list such as "[s,o,a,u]" into its fields.
// The first and last characters are always dropped because they are assumed to be the
// enclosing brackets. The remainder is split on ',' with empty fields kept, so "[a,,b]"
// yields {"a", "", "b"} and "[a,]" yields {"a", ""}.
// An argument shorter than two characters, or one with an empty body such as "[]",
// yields an empty vector.
std::vector<std::string> argumentToVector(std::string argument)
{
	std::vector<std::string> arguments;

	// substr(1, length() - 2) would throw std::out_of_range for an empty argument.
	if(argument.length() < 2)
	{
		return arguments;
	}

	const std::string body = argument.substr(1, argument.length() - 2);
	if(body.empty())
	{
		return arguments;
	}

	// A body containing k commas produces exactly k + 1 fields.
	std::string::size_type start = 0;
	while(true)
	{
		const std::string::size_type comma = body.find(',', start);
		if(comma == std::string::npos)
		{
			arguments.push_back(body.substr(start));
			break;
		}
		arguments.push_back(body.substr(start, comma - start));
		start = comma + 1;
	}

	return arguments;
}


hash_map<string, PPR> testLearnPolicy(string inputCSV, vector<string> colHeaders, vector<string> observations, vector<string> actions, 
										vector<string> states, int treeLength, string inputNet, string outputNet, vector<string> prefixes, 
										string outputFilledNet, double epsilon, string mergesCSVfile, vector<string> joinStates)
{	
	/*vector<string> joinStates;
	joinStates.push_back("Model");
	joinStates.push_back("State");*/

	inputCSV = inputCSV; 
	colHeaders = colHeaders;
	observations = observations;
	actions = actions; 
	states = states;

	time_t timeVal1=0;
	time(&timeVal1);

	BuilderCSV builder;
	hash_map<string, PPR> branches;
	branches = builder.BuildFromCSV(branches, inputCSV, colHeaders[3], actions, colHeaders[2], observations, colHeaders[1], treeLength, states, joinStates);

	//save results into net files
	InputPprIdid ipi;
	ParseListener *pl = new DefaultParseListener();
	Domain * inputID = new Domain(inputNet, pl);
	Domain * domain = ipi.EnterPprModelsIntoIdid(inputID, branches, prefixes[1], prefixes[0], branches.size(), true);

	time_t timeVal2=0;
	time(&timeVal2);
	double duration = (double)(timeVal2-timeVal1)/1000.00;
	cout << "Learning Duration = " << duration << "s" << endl;

	domain->saveAsNet(outputNet);	

	if(outputFilledNet != "")
	{
		timeVal1=0;
		time(&timeVal1);

		int merges = 0;
		//Fill PPRs

		int numTrees = branches.size();
		int numNodesBefore = 0;
		int numNodesAfter = 0;

		for(hash_map<string, PPR>::const_iterator it = branches.begin(); it != branches.end(); it++)
		{
			PPR ppr = it->second;
			numNodesBefore += ppr.GetRootNode()->ToListOfNodes().size();
			merges += ppr.FillInMissingBranches(epsilon);
			numNodesAfter += ppr.GetRootNode()->ToListOfNodes().size();
		}

		inputID = new Domain(inputNet, pl);
		domain = ipi.EnterPprModelsIntoIdid(inputID, branches, prefixes[1], prefixes[0], branches.size(), true);

		timeVal2=0;
		time(&timeVal2);
		duration = (double)(timeVal2-timeVal1)/1000.00;
		cout << "Merge Duration = " << duration << "s" << endl;
		cout << "Merges = " << merges << endl;

		domain->saveAsNet(outputFilledNet);	

		//ToDo Record merges to file
		bool exists = false;
		ifstream f(mergesCSVfile.c_str());
		if (f.good()) 
		{
			exists =  true;
		}
		f.close();

		//ofstream out(fileName.c_str());
		ofstream out;
		out.open (mergesCSVfile.c_str(), std::ofstream::out | std::ofstream::app);

		if(!exists)
		{
			out << "InputFile,NumNodesBefore,NumTrees,Merges,NumNodesAfter" << endl;
		}

		out << inputCSV << ",";
		out << numNodesBefore << ",";
		out << numTrees << ",";
		out << merges << ",";
		out << numNodesAfter << ","<< endl;

		out.close();
	}


	return branches;
}

PolicyNode * testDomainToPolicyTree(string inputDID, string actionPrefix, string observationPrefix)
{
	// Parses the DID net file and converts it into a policy tree with
	// PolicyTreeBuilder::ConvertToPolicyTree. The parsed Domain and its listener
	// are not freed here.
	PolicyTreeBuilder builder;
	Domain * didDomain = new Domain(inputDID, new DefaultParseListener());
	return builder.ConvertToPolicyTree(didDomain, actionPrefix, observationPrefix);
}

Domain * testExpandIDIDModels(string inputIDID, list<string> inputDIDs, string statePrefix, string utilityPrefix, string iObservationPrefix, string jObservationPrefix, string iActionPrefix, string jActionPRefix, string modelPrefix)
{
	// Parses the I-DID, then each DID model file (each with its own parse listener).
	// Returns whatever IdidEnterModels::EnterDidModelsIntoIdid produces from them.
	Domain * ididDomain = new Domain(inputIDID, new DefaultParseListener());

	list<Domain *> didModels;
	for(list<string>::const_iterator file = inputDIDs.begin(); file != inputDIDs.end(); ++file)
	{
		didModels.push_back(new Domain(*file, new DefaultParseListener()));
	}

	IdidEnterModels enterModels;
	return enterModels.EnterDidModelsIntoIdid(ididDomain, didModels, statePrefix, utilityPrefix, iObservationPrefix, jObservationPrefix, iActionPrefix, jActionPRefix, modelPrefix);
}

void testSimulatorIDID(string inputIDID, list<string> inputDIDs, string statePrefix, string utilityPrefix, string iObservationPrefix, 
						string jObservationPrefix, string iActionPrefix, string jActionPRefix, int numberOfSimulations, string outputFile, 
						bool modelWeight, string modelPrefix)
{
	// One parse listener is shared by every Domain parsed here. The model DIDs are
	// parsed first, then the I-DID. The simulation results are written to outputFile
	// by Simulator::SaveIDIDToCsvFile.
	ParseListener *listener = new DefaultParseListener();

	list<Domain *> didModels;
	for(list<string>::const_iterator file = inputDIDs.begin(); file != inputDIDs.end(); ++file)
	{
		didModels.push_back(new Domain(*file, listener));
	}

	Domain * ididDomain = new Domain(inputIDID, listener);

	Simulator simulator;
	list<IDidResult> simulationResults = simulator.SimulateIDID(ididDomain, didModels, statePrefix, utilityPrefix, iObservationPrefix, jObservationPrefix, iActionPrefix, jActionPRefix, numberOfSimulations, modelWeight, modelPrefix);
	simulator.SaveIDIDToCsvFile(simulationResults, outputFile);
}

void testSimulatorDID(list<string> inputFiles, string statePrefix, string observationPrefix, string actionPrefix, 
						string utilityPrefix, int numberOfSimulations, string outputFile)
{
	// Parses every DID file (each with its own parse listener) and simulates them with
	// Simulator::SimulateDID. The results are saved to outputFile via Simulator::SaveDIDToCsvFile.
	list<Domain *> didModels;
	for(list<string>::const_iterator file = inputFiles.begin(); file != inputFiles.end(); ++file)
	{
		didModels.push_back(new Domain(*file, new DefaultParseListener()));
	}

	Simulator simulator;
	list<DidResult> simulationResults = simulator.SimulateDID(didModels, statePrefix, observationPrefix, actionPrefix, utilityPrefix, numberOfSimulations);
	simulator.SaveDIDToCsvFile(simulationResults, outputFile);
}

// Disabled stub: the whole body is commented out, so this function currently does nothing
// and ignores all of its parameters. It is not invoked from main.
// NOTE(review): the commented-out code calls testPolicyWriter with two arguments. Both
// testPolicyWriter overloads now take five (input, action prefix, observation prefix,
// output folder, file extension). Update those calls before re-enabling this code.
void testIDIDSolver(string inputFile, string model1File, string model2File, int numSteps)
{
	// Previously: parse the I-DID and two model DIDs, solve with iDIDSolver, then write policies.
	/*ParseListener *pl = new DefaultParseListener();
	Domain * inputIDID = new Domain(inputFile, pl);
	Domain * inputMod0 = new Domain(model1File, pl);
	Domain * inputMod1 = new Domain(model2File, pl);

	list<Domain *> models;
	models.push_back(inputMod0);
	models.push_back(inputMod1);	

	iDIDSolver ids;

	ids.Solve(inputIDID, models, numSteps);

	testPolicyWriter(inputIDID, "C:\\Nets\\Policies\\iDID\\");
	testPolicyWriter(inputMod0, "C:\\Nets\\Policies\\Mod0\\");
	testPolicyWriter(inputMod1, "C:\\Nets\\Policies\\Mod1\\");*/
}

void testPolicyWriter(Domain * inputID, string actionPrefix, string observationPrefix, string outputFolder, string fileExt)
{
	// Announces the step on stdout, then delegates to PolicyFileWriter::WritePolicyToFile
	// for the already-loaded domain.
	cout << "Write Policy File" << endl;

	PolicyFileWriter writer;
	writer.WritePolicyToFile(inputID, actionPrefix, observationPrefix, outputFolder, fileExt);
}

void testPolicyWriter(string inputFile, string actionPrefix, string observationPrefix, string outputFolder, string fileExt)
{
	// Parses the net file, then hands the domain to the Domain* overload.
	// The parsed Domain and its listener are not freed here.
	Domain *loaded = new Domain(inputFile, new DefaultParseListener());
	testPolicyWriter(loaded, actionPrefix, observationPrefix, outputFolder, fileExt);
}

// Parses inputFile and expands it over `steps` time steps as the given diagram type
// (ExpandTimeSteps::Expand). It then tries to compile the expanded domain, writes it
// to outputFile and returns it. A Hugin error during compile is printed and otherwise
// ignored, so the expanded (possibly uncompiled) domain is still saved.
// The returned Domain is not freed here.
Domain * testExpandDomain(string inputFile, string outputFile, ExpandTimeSteps::IdType idType, int steps)
{
	ExpandTimeSteps ets;
	ParseListener *pl = new DefaultParseListener();
	Domain *inputID = new Domain(inputFile, pl);

	inputID = ets.Expand(inputID, steps, idType);

	try
	{
		inputID->compile();
		//inputID->updatePolicies();
	}
	catch (const ExceptionHugin & e)
	{
		// The Hugin C++ API throws ExceptionHugin objects by value. A pointer-only
		// handler never matches them, so compile errors would escape this function.
		cout << e.what() << endl;
	}
	catch (ExceptionHugin * e)
	{
		// Kept for any project code that throws a pointer.
		cout << e->what() << endl;
	}

	inputID->saveAsNet(outputFile);

	return inputID;
}
